# encoding: utf-8
"""
@Time   : 2018/11/13 19:57
@Author : XJH
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

"""
3.1.1 准备数据
"""
# Sample 100 evenly spaced inputs on [-1, 1] and build targets that follow
# y = 2x, perturbed by standard-normal noise scaled by 0.3.
train_X = np.linspace(start=-1, stop=1, num=100)
noise = np.random.randn(*train_X.shape) * 0.3
train_Y = 2 * train_X + noise
# To preview the raw samples, uncomment:
# plt.plot(train_X, train_Y, 'ro', label='Original data')
# plt.legend()
# plt.show()

"""
3.1.2 搭建模型
"""
# Step size used by the gradient-descent optimizer below.
learning_rate = 0.01

# Graph inputs. No shape is declared; later in this script they are fed
# both single samples (training) and the full arrays (loss evaluation).
X = tf.placeholder("float")
Y = tf.placeholder("float")

# Trainable parameters: a randomly initialised weight and a zero bias,
# each of shape [1].
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")

# Linear model: z = X * W + b.
weighted_input = tf.multiply(X, W)
z = weighted_input + b

# Mean squared error between the targets and the model output.
residual = Y - z
cost = tf.reduce_mean(tf.square(residual))

# Op that performs one gradient-descent update of W and b when run.
gradient_descent = tf.train.GradientDescentOptimizer(learning_rate)
optimizer = gradient_descent.minimize(cost)

"""
3.1.3 迭代训练模型
"""
# Op that initialises all variables; it has no effect until it is run
# inside a session.
init = tf.global_variables_initializer()
training_epochs = 20  # number of full passes over the training samples
display_step = 2      # report and record the loss every this many epochs

with tf.Session() as sess:
    sess.run(init)
    # Loss history used for plotting. Despite its name, "batchsize" stores
    # the epoch index at which each "loss" entry was recorded.
    plotdata = {"batchsize": [], "loss": []}
    # Feed dict for evaluating the cost over the whole training set.
    full_batch = {X: train_X, Y: train_Y}
    for epoch in range(training_epochs):
        # One optimizer step per (x, y) sample.
        for (x, y) in zip(train_X, train_Y):
            sess.run(optimizer, feed_dict={X: x, Y: y})

        if epoch % display_step == 0:
            # Fetch the cost and both parameters in a single run call
            # instead of three separate ones.
            loss, weight, bias = sess.run([cost, W, b], feed_dict=full_batch)
            print("Epoch:", epoch+1, "cost", loss, "W=", weight, "b=", bias)
            # The loss is numeric, so comparing it with the string "NA"
            # could never filter anything; skip NaN values explicitly so
            # they do not end up in the plotted history.
            if not np.isnan(loss):
                plotdata["batchsize"].append(epoch)
                plotdata["loss"].append(loss)

    print("Finshed!")
    # Final cost on the full training set together with the fitted
    # parameters, again fetched in one run call.
    final_cost, final_w, final_b = sess.run([cost, W, b], feed_dict=full_batch)
    print("cost=", final_cost, "W=", final_w, "b=", final_b)
    def moving_average(a, w=10):
        """
        Smooth a sequence with a trailing window of width ``w``.

        The first ``w`` entries are passed through unchanged. Every later
        entry is replaced by the mean of the ``w`` entries immediately
        before it; the entry itself is not part of its own window.

        :param a: sequence of numeric values to smooth
        :param w: window width, defaults to 10
        :return: ``a[:]`` when ``a`` has fewer than ``w`` entries, otherwise
                 a new list with the smoothed values
        """
        if len(a) < w:
            return a[:]
        smoothed = []
        for position, current in enumerate(a):
            if position < w:
                # Not enough history yet: keep the raw value.
                smoothed.append(current)
            else:
                window = a[position - w:position]
                smoothed.append(sum(window) / w)
        return smoothed
    # Figure 1: the training samples with the fitted line drawn over them.
    # The fit gets its own figure so that the loss subplot created below
    # cannot replace or overlay these axes (previously both were drawn on
    # the same implicit figure).
    fitted_w, fitted_b = sess.run([W, b])
    plt.figure(1)
    plt.plot(train_X, train_Y, 'ro', label='Original data')
    plt.plot(train_X, fitted_w * train_X + fitted_b, label='Fittedline')
    plt.legend()

    # Figure 2: smoothed training loss against the recorded x positions
    # stored under plotdata["batchsize"].
    plotdata["avgloss"] = moving_average(plotdata["loss"])
    plt.figure(2)
    plt.subplot(211)
    plt.plot(plotdata["batchsize"], plotdata["avgloss"], 'b--')
    plt.xlabel('MiniBatch number')
    plt.ylabel('Loss')
    plt.title('Minibatch run vs. Training loss')
    # A single show() call displays both figures.
    plt.show()

    """
    3.1.4 使用模型
    """
    # Evaluate the trained model at a single new input, x = 0.2.
    prediction = sess.run(z, feed_dict={X: 0.2})
    print("x=0.2, z=", prediction)